{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of Softmax Regression from Scratch\n",
    ":label:`sec_softmax_scratch`\n",
    "\n",
    "Just as we implemented linear regression from scratch,\n",
    "we believe that multiclass logistic (softmax) regression\n",
    "is similarly fundamental and you ought to know\n",
    "the gory details of how to implement it yourself.\n",
    "As with linear regression, after doing things by hand\n",
    "we will breeze through an implementation in DJL for comparison.\n",
    "To begin, let us import the familiar packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "1"
    }
   },
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils.ipynb\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/FashionMnistUtils.java\n",
    "%load ../utils/ImageUtils.java\n",
    "\n",
    "import ai.djl.basicdataset.cv.classification.FashionMnist;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will work with the Fashion-MNIST dataset, just introduced in :numref:`sec_fashion_mnist`,\n",
    "setting up an iterator with batch size $256$. We also shuffle the training set randomly, while the validation set is read in order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "boolean randomShuffle = true;\n",
    "\n",
    "// get training and validation dataset\n",
    "FashionMnist trainingSet = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, randomShuffle)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "FashionMnist validationSet = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, false)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing Model Parameters\n",
    "\n",
    "As in our linear regression example,\n",
    "each example here will be represented by a fixed-length vector.\n",
    "Each example in the raw data is a $28 \\times 28$ image.\n",
    "In this section, we will flatten each image,\n",
    "treating it as a 1D vector of length $784$.\n",
    "In the future, we will talk about more sophisticated strategies\n",
    "for exploiting the spatial structure in images,\n",
    "but for now we treat each pixel location as just another feature.\n",
    "\n",
    "Recall that in softmax regression,\n",
    "we have as many outputs as there are categories.\n",
    "Because our dataset has $10$ categories,\n",
    "our network will have an output dimension of $10$.\n",
    "Consequently, our weights will constitute a $784 \\times 10$ matrix\n",
    "and the biases will constitute a $1 \\times 10$ vector.\n",
    "As with linear regression, we will initialize our weights $W$\n",
    "with Gaussian noise and our biases to take the initial value $0$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "int numInputs = 784;\n",
    "int numOutputs = 10;\n",
    "\n",
    "NDManager manager = NDManager.newBaseManager();\n",
    "NDArray W = manager.randomNormal(0, 0.01f, new Shape(numInputs, numOutputs), DataType.FLOAT32);\n",
    "NDArray b = manager.zeros(new Shape(numOutputs), DataType.FLOAT32);\n",
    "NDList params = new NDList(W, b);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Softmax\n",
    "\n",
    "Before implementing the softmax regression model,\n",
    "let us briefly review how operators such as `sum()` work\n",
    "along specific dimensions in an `NDArray`.\n",
    "Given a matrix `X` we can sum over all elements (default) or only\n",
    "along one axis: axis 0 (`new int[]{0}`) adds up the entries of each column,\n",
    "and axis 1 (`new int[]{1}`) adds up the entries of each row.\n",
    "We wrap the axis in an `int` array since we can specify multiple axes as well.\n",
    "For example, if we call `sum()` with `new int[]{0, 1}`, it sums over both the rows and the columns.\n",
    "For a 2D array, this is the total of all its elements.\n",
    "Note that if `X` is an array with shape `(2, 3)`\n",
    "and we sum over axis 0 (`X.sum(new int[]{0})`),\n",
    "the result will be a (1D) vector with shape `(3,)`.\n",
    "If we want to keep the number of axes in the original array\n",
    "(resulting in a 2D array with shape `(1, 3)`),\n",
    "rather than collapsing out the dimension that we summed over,\n",
    "we can pass `true` as the second argument when invoking `sum()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "NDArray X = manager.create(new int[][]{{1, 2, 3}, {4, 5, 6}});\n",
    "System.out.println(X.sum(new int[]{0}, true));\n",
    "System.out.println(X.sum(new int[]{1}, true));\n",
    "System.out.println(X.sum(new int[]{0, 1}, true));"
   ]
  },
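  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a check on the three calls above, the sums can be worked out by hand.\n",
    "The following is a hand computation rather than captured output,\n",
    "and the shapes listed assume `true` is passed so that the reduced axes are kept with size $1$:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{axis } 0 &: \\begin{bmatrix} 1+4 & 2+5 & 3+6 \\end{bmatrix} = \\begin{bmatrix} 5 & 7 & 9 \\end{bmatrix}, && \\text{shape } (1, 3), \\\\\n",
    "\\text{axis } 1 &: \\begin{bmatrix} 1+2+3 \\\\ 4+5+6 \\end{bmatrix} = \\begin{bmatrix} 6 \\\\ 15 \\end{bmatrix}, && \\text{shape } (2, 1), \\\\\n",
    "\\text{axes } 0, 1 &: 1+2+3+4+5+6 = 21, && \\text{shape } (1, 1).\n",
    "\\end{aligned}\n",
    "$$"
   ]
  },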
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to implement the softmax function.\n",
    "Recall that softmax consists of three steps:\n",
    "First, we exponentiate each term (using `exp()`).\n",
    "Then, we sum over each row (we have one row per example in the batch)\n",
    "to get the normalization constants for each example.\n",
    "Finally, we divide each row by its normalization constant,\n",
    "ensuring that the result sums to $1$.\n",
    "Before looking at the code, let us recall\n",
    "how this looks expressed as an equation:\n",
    "\n",
    "$$\n",
    "\\mathrm{softmax}(\\mathbf{X})_{ij} = \\frac{\\exp(X_{ij})}{\\sum_k \\exp(X_{ik})}.\n",
    "$$\n",
    "\n",
    "The denominator, or normalization constant,\n",
    "is also sometimes called the partition function\n",
    "(and its logarithm is called the log-partition function).\n",
    "The origins of that name are in [statistical physics](https://en.wikipedia.org/wiki/Partition_function_(statistical_mechanics))\n",
    "where a related equation models the distribution\n",
    "over an ensemble of particles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [],
   "source": [
    "public NDArray softmax(NDArray X) {\n",
    "    NDArray Xexp = X.exp();\n",
    "    NDArray partition = Xexp.sum(new int[]{1}, true);\n",
    "    return Xexp.div(partition); // The broadcast mechanism is applied here\n",
    "}"
   ]
  },
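  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the steps on concrete numbers, consider a single row\n",
    "chosen (purely for illustration) so that the exponentials are whole numbers:\n",
    "\n",
    "$$\n",
    "\\mathbf{x} = (0, \\log 2, \\log 3), \\qquad\n",
    "\\exp(\\mathbf{x}) = (1, 2, 3), \\qquad\n",
    "\\sum_k \\exp(x_k) = 6, \\qquad\n",
    "\\mathrm{softmax}(\\mathbf{x}) = \\left(\\tfrac{1}{6}, \\tfrac{1}{3}, \\tfrac{1}{2}\\right).\n",
    "$$\n",
    "\n",
    "The entries are non-negative and sum to $1$.\n",
    "In `softmax()`, `partition` has shape $(n, 1)$ for a batch of $n$ rows,\n",
    "so broadcasting divides every entry of row $i$ by that row's own constant."
   ]
  },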
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, for any random input,\n",
    "we turn each element into a non-negative number.\n",
    "Moreover, each row sums up to 1,\n",
    "as is required for a probability distribution.\n",
    "Note that while this looks correct mathematically,\n",
    "we were a bit sloppy in our implementation\n",
    "because we failed to take precautions against numerical overflow or underflow\n",
    "due to large (or very small) elements of the matrix,\n",
    "precautions of the kind we did take in :numref:`sec_naive_bayes`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "7"
    }
   },
   "outputs": [],
   "source": [
    "NDArray X = manager.randomNormal(new Shape(2, 5));\n",
    "NDArray Xprob = softmax(X);\n",
    "System.out.println(Xprob);\n",
    "System.out.println(Xprob.sum(new int[]{1}));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Model\n",
    "\n",
    "Now that we have defined the softmax operation,\n",
    "we can implement the softmax regression model.\n",
    "The code below defines the forward pass through the network.\n",
    "Note that we flatten each original image in the batch\n",
    "into a vector with length `numInputs` with the `reshape()` function\n",
    "before passing the data through our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "8"
    }
   },
   "outputs": [],
   "source": [
    "// We need to wrap `net()` in a class so that we can reference the method\n",
    "// and pass it as a parameter to a function or save it in a variable\n",
    "public class Net {\n",
    "    public static NDArray net(NDArray X) {\n",
    "        NDArray currentW = params.get(0);\n",
    "        NDArray currentB = params.get(1);\n",
    "        return softmax(X.reshape(new Shape(-1, numInputs)).dot(currentW).add(currentB));\n",
    "    }\n",
    "}"
   ]
  },
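  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written as an equation, the forward pass above is the following.\n",
    "The symbols $\\mathbf{O}$ (the outputs before softmax) and $\\hat{\\mathbf{Y}}$ (the predicted probabilities)\n",
    "are notation introduced here for illustration, and $n$ is the number of examples in the batch (at most $256$ here):\n",
    "\n",
    "$$\n",
    "\\mathbf{O} = \\mathbf{X} \\mathbf{W} + \\mathbf{b}, \\qquad\n",
    "\\hat{\\mathbf{Y}} = \\mathrm{softmax}(\\mathbf{O}),\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\mathbf{X} \\in \\mathbb{R}^{n \\times 784}, \\quad\n",
    "\\mathbf{W} \\in \\mathbb{R}^{784 \\times 10}, \\quad\n",
    "\\mathbf{b} \\in \\mathbb{R}^{10}, \\quad\n",
    "\\mathbf{O}, \\hat{\\mathbf{Y}} \\in \\mathbb{R}^{n \\times 10}.\n",
    "$$\n",
    "\n",
    "The bias is broadcast across the $n$ rows,\n",
    "and `reshape(new Shape(-1, numInputs))` infers $n$ from the size of the input."
   ]
  },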
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Loss Function\n",
    "\n",
    "Next, we need to implement the cross-entropy loss function,\n",
    "introduced in :numref:`sec_softmax`.\n",
    "This may be the most common loss function\n",
    "in all of deep learning because, at the moment,\n",
    "classification problems far outnumber regression problems.\n",
    "\n",
    "Recall that cross-entropy takes the negative log likelihood\n",
    "of the predicted probability assigned to the true label, $-\\log P(y \\mid x)$.\n",
    "Rather than iterating over the predictions with a Java `for` loop\n",
    "(which tends to be inefficient),\n",
    "we can use the NDArray `get()` function in conjunction with `NDIndex`\n",
    "to select the appropriate terms\n",
    "from the matrix of softmax entries. This operation is known as `pick()`\n",
    "in some other frameworks, such as MXNet.\n",
    "Below, we illustrate the usage on a toy example,\n",
    "with $3$ categories and $2$ examples.\n",
    "\n",
    "The `\":, {}\"` part of the `NDIndex` selects every row,\n",
    "and `manager.create(new int[]{0, 2})` creates an\n",
    "`NDArray` with the values 0 and 2, which picks element 0 of the first row\n",
    "and element 2 of the second row.\n",
    "\n",
    "Note: when using `NDIndex` in this way, the `NDArray` passed in for picking\n",
    "indices must be of type `int` or `long`. You can use the `toType()` function\n",
    "to change the type of an `NDArray`, as shown further below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "9"
    }
   },
   "outputs": [],
   "source": [
    "NDArray yHat = manager.create(new float[][]{{0.1f, 0.3f, 0.6f}, {0.3f, 0.2f, 0.5f}});\n",
    "yHat.get(new NDIndex(\":, {}\", manager.create(new int[]{0, 2})));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can implement the cross-entropy loss function efficiently with just one line of code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "10"
    }
   },
   "outputs": [],
   "source": [
    "// Cross-entropy only depends on the probability assigned to the true class,\n",
    "// so for each row we pick the column whose index is the label\n",
    "public class LossFunction {\n",
    "    public static NDArray crossEntropy(NDArray yHat, NDArray y) {\n",
    "        // Here, y is not guaranteed to be of datatype int or long;\n",
    "        // in our case we know it is a float32.\n",
    "        // Rather than the string form of NDIndex shown above,\n",
    "        // we build the index explicitly: `addAllDim(n)` keeps the first\n",
    "        // n dimensions whole (n = rank - 1, i.e., every row), and\n",
    "        // `addPickDim(y)` picks one entry along the last dimension per row.\n",
    "        NDIndex pickIndex = new NDIndex()\n",
    "                 .addAllDim(Math.floorMod(-1, yHat.getShape().dimension()))\n",
    "                 .addPickDim(y);\n",
    "        return yHat.get(pickIndex).log().neg();\n",
    "    }\n",
    "}"
   ]
  },
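  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check, take the toy `yHat` from the example above\n",
    "and suppose the true labels are $0$ and $2$ (the same indices used in the pick example).\n",
    "Writing $\\hat{y}_{i,j}$ for entry $(i, j)$ of `yHat` and $l_i$ for the loss of example $i$\n",
    "(notation assumed here for illustration), and using the natural logarithm:\n",
    "\n",
    "$$\n",
    "l_i = -\\log \\hat{y}_{i, y_i}, \\qquad\n",
    "l_1 = -\\log 0.1 \\approx 2.303, \\qquad\n",
    "l_2 = -\\log 0.5 \\approx 0.693.\n",
    "$$\n",
    "\n",
    "The first example, where the true class received only probability $0.1$,\n",
    "is penalized more than three times as heavily as the second."
   ]
  },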
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification Accuracy\n",
    "\n",
    "Given the predicted probability distribution `yHat`,\n",
    "we typically choose the class with highest predicted probability\n",
    "whenever we must output a *hard* prediction.\n",
    "Indeed, many applications require that we make a choice.\n",
    "Gmail must categorize an email into Primary, Social, Updates, or Forums.\n",
    "It might estimate probabilities internally,\n",
    "but at the end of the day it has to choose one among the categories.\n",
    "\n",
    "When predictions are consistent with the actual category `y`, they are correct.\n",
    "The classification accuracy is the fraction of all predictions that are correct.\n",
    "Although it can be difficult to optimize accuracy directly (it is not differentiable),\n",
    "it is often the performance metric that we care most about,\n",
    "and we will nearly always report it when training classifiers.\n",
    "\n",
    "To compute accuracy we do the following:\n",
    "First, we execute `yHat.argMax(1)` where 1 is the axis\n",
    "to gather the predicted classes\n",
    "(given by the indices for the largest entries in each row).\n",
    "The result has the same shape as the variable `y`.\n",
    "Now we just need to check how frequently the two match.\n",
    "Since the equality function `eq()` is datatype-sensitive\n",
    "(e.g., an `int32` and a `float32` are never equal),\n",
    "we also need to convert both to the same type (we pick `int32`).\n",
    "The result is an `NDArray` containing entries of 0 (false) and 1 (true).\n",
    "We then sum the number of true entries and convert the result to a float.\n",
    "Finally, we get the mean by dividing by the number of data points;\n",
    "the `accuracy()` function below returns the count, so the division is left to the caller."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "11"
    }
   },
   "outputs": [],
   "source": [
    "// Saved in the utils for later use\n",
    "// Returns the number of correct predictions (not yet divided by the number of examples)\n",
    "public float accuracy(NDArray yHat, NDArray y) {\n",
    "    // If axis 1 has more than one entry, each row of yHat holds\n",
    "    // one score per class and must be reduced with argMax\n",
    "    if (yHat.getShape().size(1) > 1) {\n",
    "        // argMax(1) returns the index of the largest entry in each row\n",
    "        // Convert the predictions and y to the same data type (int32)\n",
    "        // and count the entries where they agree\n",
    "        return yHat.argMax(1).toType(DataType.INT32, false).eq(y.toType(DataType.INT32, false))\n",
    "            .sum().toType(DataType.FLOAT32, false).getFloat();\n",
    "    }\n",
    "    return yHat.toType(DataType.INT32, false).eq(y.toType(DataType.INT32, false))\n",
    "        .sum().toType(DataType.FLOAT32, false).getFloat();\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will continue to use the toy `yHat` defined above\n",
    "as the predicted probability distributions,\n",
    "and define `y` as the corresponding labels.\n",
    "We can see that the first example's predicted category is $2$\n",
    "(the largest element of the row is $0.6$ with an index of $2$),\n",
    "which is inconsistent with the actual label, $0$.\n",
    "The second example's predicted category is $2$\n",
    "(the largest element of the row is $0.5$ with an index of $2$),\n",
    "which is consistent with the actual label, $2$.\n",
    "Therefore, the classification accuracy rate for these two examples is $0.5$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "12"
    }
   },
   "outputs": [],
   "source": [
    "NDArray y = manager.create(new int[]{0,2});\n",
    "accuracy(yHat, y) / y.size();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we can evaluate the accuracy for model `net` on the dataset\n",
    "(accessed via `dataIterator`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.function.UnaryOperator; \n",
    "import java.util.function.BinaryOperator; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "13"
    }
   },
   "outputs": [],
   "source": [
    "// Saved in the utils for future use\n",
    "public float evaluateAccuracy(UnaryOperator<NDArray> net, Iterable<Batch> dataIterator) {\n",
    "    Accumulator metric = new Accumulator(2);  // numCorrectExamples, numExamples\n",
    "    // Accumulate over every batch of the dataset, not just the first one\n",
    "    for (Batch batch : dataIterator) {\n",
    "        NDArray X = batch.getData().head();\n",
    "        NDArray y = batch.getLabels().head();\n",
    "        metric.add(new float[]{accuracy(net.apply(X), y), (float)y.size()});\n",
    "        batch.close();\n",
    "    }\n",
    "\n",
    "    return metric.get(0) / metric.get(1);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here `Accumulator` is a utility class to accumulate sums over multiple numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "14"
    }
   },
   "outputs": [],
   "source": [
    "// Saved in utils for future use        \n",
    "/* Sum a list of numbers over time */\n",
    "public class Accumulator {\n",
    "    float[] data;\n",
    "    \n",
    "    public Accumulator(int n) {\n",
    "        data = new float[n];\n",
    "    }\n",
    "    \n",
    "    /* Adds a set of numbers to the array */\n",
    "    public void add(float[] args) {\n",
    "        for (int i = 0; i < args.length; i++) {\n",
    "            data[i] += args[i];\n",
    "        }\n",
    "    }\n",
    "\n",
    "    /* Resets the array */\n",
    "    public void reset() {\n",
    "        Arrays.fill(data, 0f);\n",
    "    }\n",
    "\n",
    "    /* Returns the data point at the given index */\n",
    "    public float get(int index) {\n",
    "        return data[index];\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because we initialized the `net` model with random weights,\n",
    "the accuracy of this model should be close to random guessing,\n",
    "i.e., $0.1$ for $10$ classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "15"
    }
   },
   "outputs": [],
   "source": [
    "evaluateAccuracy(Net::net, validationSet.getData(manager));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Training\n",
    "\n",
    "The training loop for softmax regression should look strikingly familiar\n",
    "if you read through our implementation\n",
    "of linear regression in :numref:`sec_linear_scratch`.\n",
    "Here we refactor the implementation to make it reusable.\n",
    "First, we define a function to train for one data epoch.\n",
    "Note that `updater()` is a general function to update the model parameters,\n",
    "which accepts the batch size as an argument.\n",
    "Currently, it is a wrapper of `Training.sgd()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "16"
    }
   },
   "outputs": [],
   "source": [
    "@FunctionalInterface\n",
    "public static interface ParamConsumer {\n",
    "     void accept(NDList params, float lr, int batchSize);\n",
    "}\n",
    "\n",
    "public float[] trainEpochCh3(UnaryOperator<NDArray> net, Iterable<Batch> trainIter, BinaryOperator<NDArray> loss, ParamConsumer updater) {\n",
    "    Accumulator metric = new Accumulator(3); // trainLossSum, trainAccSum, numExamples\n",
    "    \n",
    "    // Attach Gradients\n",
    "    for (NDArray param : params) {\n",
    "        param.setRequiresGradient(true);\n",
    "    }\n",
    "    \n",
    "    for (Batch batch : trainIter) {\n",
    "        NDArray X = batch.getData().head();\n",
    "        NDArray y = batch.getLabels().head();\n",
    "        X = X.reshape(new Shape(-1, numInputs));\n",
    "            \n",
    "        try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "            // Minibatch loss in X and y\n",
    "            NDArray yHat = net.apply(X);\n",
    "            NDArray l = loss.apply(yHat, y);\n",
    "            gc.backward(l);  // Compute gradient on l with respect to w and b\n",
    "            metric.add(new float[]{l.sum().toType(DataType.FLOAT32, false).getFloat(), \n",
    "                                   accuracy(yHat, y), \n",
    "                                   (float)y.size()});\n",
    "            gc.close();\n",
    "        }\n",
    "        updater.accept(params, lr, batch.getSize());  // Update parameters using their gradient\n",
    "        \n",
    "        batch.close();\n",
    "    }\n",
    "    // Return trainLoss, trainAccuracy\n",
    "    return new float[]{metric.get(0) / metric.get(2), metric.get(1) / metric.get(2)}; \n",
    "}"
   ]
  },
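  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the quantities in `trainEpochCh3()` can be written out as follows.\n",
    "The update rule assumes that `Training.sgd()` implements the minibatch stochastic gradient descent step\n",
    "from :numref:`sec_linear_scratch`; that helper is not shown in this section, so treat the first equation as a sketch.\n",
    "Here $\\mathcal{B}$ is a minibatch, $\\eta$ is the learning rate `lr`, $l_i$ is the loss of example $i$,\n",
    "$\\hat{y}_i$ is its predicted class, and $N$ is the number of examples seen in the epoch\n",
    "(symbols introduced here for illustration):\n",
    "\n",
    "$$\n",
    "(\\mathbf{W}, \\mathbf{b}) \\leftarrow (\\mathbf{W}, \\mathbf{b}) - \\frac{\\eta}{\\lvert \\mathcal{B} \\rvert} \\sum_{i \\in \\mathcal{B}} \\nabla_{(\\mathbf{W}, \\mathbf{b})} \\, l_i,\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\text{trainLoss} = \\frac{1}{N} \\sum_{i=1}^{N} l_i, \\qquad\n",
    "\\text{trainAccuracy} = \\frac{1}{N} \\sum_{i=1}^{N} \\mathbf{1}[\\hat{y}_i = y_i].\n",
    "$$\n",
    "\n",
    "Both metrics are accumulated while the parameters are still changing,\n",
    "so they describe the epoch as a whole rather than the final parameters."
   ]
  },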
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before showing the implementation of the training function, we define a utility class that plots data incrementally as an animation. Again, it aims to simplify the code in later chapters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tech.tablesaw.api.Row;\n",
    "import tech.tablesaw.columns.Column;\n",
    "\n",
    "// Saved in utils\n",
    "/* Animates a graph with real-time data. */\n",
    "class Animator {\n",
    "    private String id; // Id reference of graph(for updating graph)\n",
    "    private Table data; // Data Points\n",
    "    \n",
    "    public Animator() {\n",
    "        id = \"\";\n",
    "        \n",
    "        // Incrementally plot data\n",
    "        data = Table.create(\"Data\")\n",
    "        .addColumns(\n",
    "            FloatColumn.create(\"epoch\", new float[]{}),\n",
    "            FloatColumn.create(\"value\", new float[]{}),\n",
    "            StringColumn.create(\"metric\", new String[]{})\n",
    "        );\n",
    "    }\n",
    "\n",
    "    // Add a single metric to the table\n",
    "    public void add(float epoch, float value, String metric) {\n",
    "        Row newRow = data.appendRow();\n",
    "        newRow.setFloat(\"epoch\", epoch);\n",
    "        newRow.setFloat(\"value\", value);\n",
    "        newRow.setString(\"metric\", metric);\n",
    "    }\n",
    "    \n",
    "    // Add accuracy, train accuracy, and train loss metrics for a given epoch\n",
    "    // Then plot it on the graph\n",
    "    public void add(float epoch, float accuracy, float trainAcc, float trainLoss) {\n",
    "        add(epoch, trainLoss, \"train loss\");\n",
    "        add(epoch, trainAcc, \"train accuracy\");\n",
    "        add(epoch, accuracy, \"test accuracy\");\n",
    "        show();\n",
    "    }\n",
    "    \n",
    "    // Display the graph\n",
    "    public void show() {\n",
    "        if (id.equals(\"\")) {\n",
    "            id = display(LinePlot.create(\"\", data, \"epoch\", \"value\", \"metric\"));\n",
    "            return;\n",
    "        }\n",
    "        update();\n",
    "    }\n",
    "    \n",
    "    // Update the graph\n",
    "    public void update() {\n",
    "        updateDisplay(id, LinePlot.create(\"\", data, \"epoch\", \"value\", \"metric\"));\n",
    "    }\n",
    "    \n",
    "    // Returns the column at the given index\n",
    "    // if it is a float column\n",
    "    // Otherwise returns null\n",
    "    public float[] getY(Table t, int index) {\n",
    "        Column c = t.column(index);\n",
    "        if (c.type() == ColumnType.FLOAT) {\n",
    "            // Copy element by element: `System.arraycopy` cannot copy\n",
    "            // from an Object[] of boxed values into a float[]\n",
    "            FloatColumn fc = (FloatColumn) c;\n",
    "            float[] newArray = new float[fc.size()];\n",
    "            for (int i = 0; i < fc.size(); i++) {\n",
    "                newArray[i] = fc.getFloat(i);\n",
    "            }\n",
    "            return newArray;\n",
    "        }\n",
    "        return null;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training function then runs multiple epochs and visualizes the training progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "18"
    }
   },
   "outputs": [],
   "source": [
    "public void trainCh3(UnaryOperator<NDArray> net, Dataset trainDataset, Dataset testDataset, \n",
    "                     BinaryOperator<NDArray> loss, int numEpochs, ParamConsumer updater) \n",
    "                                                            throws IOException, TranslateException {\n",
    "    Animator animator = new Animator();\n",
    "    for (int i = 1; i <= numEpochs; i++) {\n",
    "        float[] trainMetrics = trainEpochCh3(net, trainDataset.getData(manager), loss, updater);\n",
    "        float accuracy = evaluateAccuracy(net, testDataset.getData(manager));\n",
    "        float trainAccuracy = trainMetrics[1];\n",
    "        float trainLoss = trainMetrics[0];\n",
    "        \n",
    "        animator.add(i, accuracy, trainAccuracy, trainLoss);\n",
    "        System.out.printf(\"Epoch %d: Test Accuracy: %f\\n\", i, accuracy);\n",
    "        System.out.printf(\"Train Accuracy: %f\\n\", trainAccuracy);\n",
    "        System.out.printf(\"Train Loss: %f\\n\", trainLoss);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, we use minibatch stochastic gradient descent\n",
    "to optimize the loss function of the model.\n",
    "Note that the number of epochs (`numEpochs`),\n",
    "and learning rate (`lr`) are both adjustable hyper-parameters.\n",
    "By changing their values, we may be able\n",
    "to increase the classification accuracy of the model.\n",
    "In practice we will want to split our data three ways\n",
    "into training, validation, and test data,\n",
    "using the validation data to choose\n",
    "the best values of our hyper-parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "18"
    }
   },
   "outputs": [],
   "source": [
    "int numEpochs = 5;\n",
    "float lr = 0.1f;\n",
    "\n",
    "public class Updater {\n",
    "    public static void updater(NDList params, float lr, int batchSize) {\n",
    "        Training.sgd(params, lr, batchSize);\n",
    "    }\n",
    "}\n",
    "\n",
    "trainCh3(Net::net, trainingSet, validationSet, LossFunction::crossEntropy, numEpochs, Updater::updater);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction\n",
    "\n",
    "Now that training is complete,\n",
    "our model is ready to classify some images.\n",
    "Given a series of images,\n",
    "we will compare their actual labels\n",
    "(first line of text output)\n",
    "and the model predictions\n",
    "(second line of text output)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "19"
    }
   },
   "outputs": [],
   "source": [
    "// Saved in the FashionMnistUtils class for later use\n",
    "// Number should be < batchSize for this function to work properly\n",
    "public BufferedImage predictCh3(UnaryOperator<NDArray> net, ArrayDataset dataset, int number, NDManager manager) \n",
    "    throws IOException, TranslateException {\n",
    "    int[] predLabels = new int[number];\n",
    "    \n",
    "    Batch batch = dataset.getData(manager).iterator().next();\n",
    "    NDArray X = batch.getData().head();\n",
    "    int[] yHat = net.apply(X).argMax(1).toType(DataType.INT32, false).toIntArray();\n",
    "    for (int i = 0; i < number; i++) {\n",
    "        predLabels[i] = yHat[i];\n",
    "    }\n",
    "    \n",
    "    return FashionMnistUtils.showImages(dataset, predLabels, 28, 28, 4, manager);\n",
    "}\n",
    "\n",
    "predictCh3(Net::net, validationSet, 6, manager)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "With softmax regression, we can train models for multi-category classification.\n",
    "The training loop is very similar to that in linear regression:\n",
    "retrieve and read data, define models and loss functions,\n",
    "then train models using optimization algorithms.\n",
    "As you will soon find out, most common deep learning models\n",
    "have similar training procedures.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. In this section, we directly implemented the softmax function based on the mathematical definition of the softmax operation. What problems might this cause (hint: try to calculate the size of $\\exp(50)$)?\n",
    "1. The function `crossEntropy()` in this section is implemented according to the definition of the cross-entropy loss function. What could be the problem with this implementation (hint: consider the domain of the logarithm)?\n",
    "1. What solutions can you think of to fix the two problems above?\n",
    "1. Is it always a good idea to return the most likely label? For example, would you do this for medical diagnosis?\n",
    "1. Assume that we want to use softmax regression to predict the next word based on some features. What are some problems that might arise from a large vocabulary?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
